#!/usr/bin/python

import code,pickle,sys,os,re,traceback,cPickle
import matplotlib
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
import random as pyrandom
from optparse import OptionParser
from pylab import *
from scipy import stats
import ocrolib
import heapq
from ocrolib import dbtables,quant,utils,gatedmodel,lru,docproc,Record,mlp

def log_progress(fmt,*args):
    """Write a transient progress message to stderr.

    `fmt` is a %-style format string and `args` are its values.  The
    formatted text is followed by the ANSI "erase to end of line" sequence
    and a carriage return, so the next message overwrites this one on the
    same terminal line.
    """
    emit = sys.stderr.write
    message = fmt%args
    emit(message)
    # ESC[K clears leftovers of a longer previous message; \r rewinds the cursor
    emit("\033[K\r")

parser = OptionParser(usage="""
%prog [options] input.db 

Apply a classifier to the characters in the database and store the classification
results in the database.

By default, stores the classification results in separate columns, pred and pcost.
The cost of the actual character in cls is stored in pocost.  An overall error
rate is returned.  If -e is given, only computes the error rate but stores nothing.

Alternatively, can update the cls and cost columsn in the database themselves
when the -C option is given.  This is useful, for example, on the output from
ocropus-extract-rseg or clustered output.
""")

parser.add_option("-D","--display",help="display the characters",action="store_true")
parser.add_option("-e","--error",help="only calculate error",action="store_true")
parser.add_option("-C","--setclass",help="set class, not pred",action="store_true")
parser.add_option("-N","--limit",help="max # samples",default=1000000000)
parser.add_option("-m","--model",help="model used for classification",default="default.model")
parser.add_option("-t","--table",help="table",default="chars")
(options,args) = parser.parse_args()

if len(args)!=1:
    print "usage: ..."
    sys.exit(1)

badcost = 9999.0
table = options.table

print "opening",args[0]

db = utils.chardb(args[0],options.table)
db.execute("pragma synchronous=0")
model = ocrolib.load_component(options.model)

target = next(db.execute("select count(*) from %s"%options.table))[0]

print "model",model

total = 0
errors = 0

if options.display: ion(); gray()

for row in db.execute("select * from %s"%(options.table,)):
    image = utils.blob2image(row.image)/255.0
    rel = docproc.rel_geo_normalize(row.rel)
    cls = row.cls
    outputs = model.coutputs(image,geometry=rel)
    outputs = [(c,-log(max(p,1e-6))) for c,p in outputs]
    outputs = sorted(outputs,key=lambda x:x[1])
    total += 1
    if len(outputs)<1:
        errors += 1
        pred = "~"
        pcost = badcost
        pocost = badcost
    else:
        pred,pcost = outputs[0]
        pocost = list(p for c,p in outputs if c==cls)
        pocost = pocost[0] if len(pocost)>0 else badcost
    if options.display: 
        clf()
        imshow(image)
        draw()
        print row.id,cls,pocost,pred,pcost,outputs[:3]
        raw_input()
    if options.setclass:
        db.execute("update chars set cls=?,cost=?,pred=NULL,pcost=NULL where id=?",
            (pred,pcost,row.id))
    else:
        if pred!=cls: errors += 1
        db.execute("update chars set pred=?,pcost=?,pocost=? where id=?",
            (pred,pcost,pocost,row.id))
    if total%100==0: 
        sys.stderr.write("%8d %8d (%.5f)\r"%(total,errors,errors*1.0/total))
    if total%1000==0: db.commit()
    if total>=options.limit: break

print
if not options.setclass:
    print "%8d %8d (%.5f)"%(total,errors,errors*1.0/total)

db.commit()
db.close()
del db
